import logging
import torch
import random
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import random
from copy import deepcopy
from collections import deque
import time




def WFA(s_init, s_current, h_r, k, sp, L = -10000, *, verbose=False):
    """Score serving the latest request with server ``k`` under the Work Function Algorithm.

    Builds a min-cost-flow network whose optimal flow is the cheapest offline
    way to start from ``s_init``, serve every request of ``h_r`` in arrival
    order, and finish in ``s_current`` with server ``k`` replaced by the latest
    request ``h_r[-1]``. The return value is that offline cost plus the online
    cost ``sp[s_current[k]][h_r[-1]]`` of moving server ``k``; callers take the
    ``k`` that minimises it.

    Args:
        s_init: initial server locations (graph nodes). Repeated locations are
            allowed (several servers on one node).
        s_current: current server locations; ``s_current[k]`` is the candidate.
        h_r: request history in arrival order; ``h_r[-1]`` is the request
            being served. Must be non-empty (``h_r[-1]`` raises IndexError).
        k: index into ``s_current`` of the server under consideration.
        sp: nested mapping of distances, ``sp[u][v]``.
        L: large negative weight put on every request edge so the min-cost flow
            routes through every request (assumes distances are small compared
            with ``|L|``); ``len(h_r) * L`` is removed again from the result.
        verbose: print the offline cost before returning.

    Returns:
        Offline cost of the constrained schedule + online cost of moving ``k``.
    """
    # Target configuration: server k has moved onto the newest request.
    target = list(s_current)
    target[k] = h_r[-1]
    n_req = len(h_r)

    G_dash = nx.DiGraph()
    G_dash.add_node('source')
    G_dash.add_node('sink')

    # Server nodes are named by their *position* in the configuration, not by
    # location: if two servers share a location, location-based names would
    # merge their nodes, the max flow would drop to k-1 and one server's cost
    # would silently vanish.
    for a, loc in enumerate(s_init):
        start = 's_start_' + str(a)
        G_dash.add_edge('source', start, capacity=1, weight=0)
        # A server may travel to any request (this only *enables* service).
        for j, req in enumerate(h_r):
            G_dash.add_edge(start, str(j), capacity=1, weight=sp[loc][req])
        # ...or go straight to any slot of the target configuration.
        for b, dest in enumerate(target):
            G_dash.add_edge(start, 's_curr_' + str(b), capacity=1, weight=sp[loc][dest])

    for i, req in enumerate(h_r):
        node = str(i)
        served = node + '_dash'
        # r -> r' carries the large negative reward L.
        G_dash.add_edge(node, served, capacity=1, weight=L)
        # After serving r_i a server may serve any later request r_j (i < j).
        for j in range(i + 1, n_req):
            G_dash.add_edge(served, str(j), capacity=1, weight=sp[req][h_r[j]])
        # ...or finish in any slot of the target configuration.
        for b, dest in enumerate(target):
            G_dash.add_edge(served, 's_curr_' + str(b), capacity=1, weight=sp[req][dest])

    # Each target slot absorbs exactly one server.
    for b in range(len(target)):
        G_dash.add_edge('s_curr_' + str(b), 'sink', capacity=1, weight=0)

    flow = nx.max_flow_min_cost(G_dash, 'source', 'sink')
    offline_cost = nx.cost_of_flow(G_dash, flow) - (n_req * L)
    if verbose:
        print("Offline Cost is: {}".format(offline_cost))
    return offline_cost + sp[s_current[k]][h_r[-1]]

def opt_off(s_init, h_r, sp, L = -10000, *, verbose=False):
    """Return the optimal offline cost of serving ``h_r`` from ``s_init``.

    Builds a min-cost-flow network in which servers start at ``s_init``, serve
    the requests of ``h_r`` in arrival order, and may finish anywhere (after
    their last request, or without moving at all).

    Args:
        s_init: initial server locations (graph nodes). Repeated locations are
            allowed (several servers on one node).
        h_r: request sequence in arrival order.
        sp: nested mapping of distances, ``sp[u][v]``.
        L: large negative weight put on every request edge so the min-cost flow
            routes through every request (assumes distances are small compared
            with ``|L|``); ``len(h_r) * L`` is removed again from the result.
        verbose: print the offline cost before returning.

    Returns:
        The offline cost only (no online term).
    """
    G_dash = nx.DiGraph()
    G_dash.add_node('source')
    G_dash.add_node('sink')

    # Server nodes are named by position, not location, so servers sharing a
    # location are not merged into one node (which would lose a unit of flow).
    for a, loc in enumerate(s_init):
        start = 's_start_' + str(a)
        G_dash.add_edge('source', start, capacity=1, weight=0)
        # A server that never moves goes straight to the sink at no cost.
        G_dash.add_edge(start, 'sink', capacity=1, weight=0)
        # A server may travel to any request (this only *enables* service).
        for j, req in enumerate(h_r):
            G_dash.add_edge(start, str(j), capacity=1, weight=sp[loc][req])

    for i, req in enumerate(h_r):
        node = str(i)
        served = node + '_dash'
        # r -> r' carries the large negative reward L.
        G_dash.add_edge(node, served, capacity=1, weight=L)
        # After serving r_i a server may serve any later request r_j (i < j).
        for j in range(i + 1, len(h_r)):
            G_dash.add_edge(served, str(j), capacity=1, weight=sp[req][h_r[j]])
        # ...or stop here: servers may end anywhere offline.
        G_dash.add_edge(served, 'sink', capacity=1, weight=0)

    flow = nx.max_flow_min_cost(G_dash, 'source', 'sink')
    offline_cost = nx.cost_of_flow(G_dash, flow) - (len(h_r) * L)
    if verbose:
        print("Offline Cost is: {}".format(offline_cost))
    return offline_cost

class WorkFunction():
    """Drive an environment with the Work Function Algorithm (WFA) for k-server.

    ``env`` is expected to expose ``batch_size`` (must be 1), ``device``,
    ``num_servers``, ``num_nodes``, ``graph`` (used for all-pairs shortest
    paths) and ``step(action, state) -> (next_state, reward, _)``. The initial
    state built here has shape ``(1, num_servers + 1)``: the sorted server
    locations followed by a request. ``env.step`` is presumably expected to
    return ``next_state`` in the same layout — TODO confirm against the env.
    """

    def __init__(self, env, num_requests = 100):
        """Sample initial servers and a window of ``num_requests`` random requests.

        The first request is never on a server, and no two consecutive requests
        are equal.

        Raises:
            ValueError: if the batch size is not 1, or if no node is free of
                servers (the first request could never be placed).
        """
        self.env = env
        self.batch_size = env.batch_size  # has to be 1 for Qtable
        self.num_requests = num_requests
        if self.batch_size > 1:
            raise ValueError('The environment batch size has to be equal to 1')
        if self.env.num_servers >= self.env.num_nodes:
            # Without a free node the first-request loop below never terminates.
            raise ValueError('The environment needs more nodes than servers')
        self.device = self.env.device
        # Placeholder entries; estimate() drops the first num_servers values.
        self.total_reward = torch.empty((self.env.num_servers)).to(self.device)
        # Sliding window of requests: appending beyond maxlen drops the oldest.
        self.requests = deque(maxlen=num_requests)
        self.server_init = sorted(random.sample(range(self.env.num_nodes), self.env.num_servers))

        # First request must not coincide with a server location.
        first_request = random.randint(0, self.env.num_nodes - 1)
        while first_request in self.server_init:
            first_request = random.randint(0, self.env.num_nodes - 1)
        self.requests.append(first_request)

        # Remaining requests: consecutive requests must differ.
        for _ in range(num_requests - 1):
            while True:
                new_request = random.randint(0, self.env.num_nodes - 1)
                if new_request != self.requests[-1]:
                    break
            self.requests.append(new_request)

        self.sp = dict(nx.all_pairs_shortest_path_length(self.env.graph))

    def _best_server(self, s_init, s_current, h_r):
        """Return the server index k minimising ``WFA`` (lowest k on ties)."""
        return min(
            range(len(self.server_init)),
            key=lambda k: WFA(s_init=s_init, s_current=s_current, h_r=h_r, k=k, sp=self.sp),
        )

    def estimate(self, num_steps=1, print_results = False):
        """Run WFA online against ``self.env`` for ``len(self.requests)`` steps.

        Each step:
        1. Read the current server locations from the state.
        2. Choose via WFA which of them serves the newest request in the window.
        3. Step the env with that location as the action.
        4. Append the env's next request to the window, which drops the oldest.
        5. Slide the window's starting configuration ``self.server_init``
           forward by moving one server onto the request being dropped.

        Mutates ``self.server_init``, ``self.requests`` and
        ``self.total_reward``; rewards from earlier calls stay in
        ``self.total_reward`` and are included in the returned statistics.

        Args:
            num_steps: currently unused; kept for interface compatibility.
            print_results: print the running and final average reward.

        Returns:
            (mean, 25% quantile, 75% quantile, all collected rewards).
        """
        steps_for_display = int(10000 / self.batch_size)

        initial_servers = torch.tensor(sorted(self.server_init))
        state = torch.cat((initial_servers, torch.tensor([self.requests[0]]))).unsqueeze(0).to(self.device)
        start_time = time.time()

        for r in range(0, len(self.requests)):
            # Current configuration must be re-read every step; a value
            # captured before the loop would keep WFA (and the chosen action)
            # pinned to the initial server positions.
            actions_all = state[:, :-1]
            server_current = actions_all.squeeze(0).tolist()
            history = list(self.requests)

            # Which current server should answer the newest request?
            best_k_current = self._best_server(self.server_init, server_current, history)

            # Slide the start configuration onto the oldest request, which
            # leaves the window when the next request is appended.
            server_init_prev = self.server_init.copy()
            best_k = self._best_server(server_init_prev, self.server_init, history[:1])
            self.server_init[best_k] = self.requests[0]

            action = actions_all[:, best_k_current]
            next_state, reward, _ = self.env.step(action.unsqueeze(0), state)
            self.requests.append(next_state[:, -1].item())

            self.total_reward = torch.cat((self.total_reward, reward.to(self.device)), 0)
            state = next_state.to(self.device)
            if print_results and (r + 1) % steps_for_display == 0:
                print(f"Step {r+1},  Average Reward {torch.mean(self.total_reward[self.env.num_servers:]):.2f}")

        if print_results:
            print(f"Average Reward {torch.mean(self.total_reward[self.env.num_servers:]):.2f}")

        # Skip the placeholder entries created in __init__.
        estimates = self.total_reward[self.env.num_servers:]

        elapsed_time = time.time() - start_time
        print(f"Overall, it took took {round(elapsed_time, 3)}")

        return torch.mean(estimates), torch.quantile(estimates, 0.25), torch.quantile(estimates, 0.75), estimates

